/*
 * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 */

// Expands to the signature (no body, no trailing `;`) of a one-pass launcher function:
//   void group_norm_nhwc_<PASS_NAME>_one_pass_<CHANNELS_PER_GROUP>_<ACTS_PER_BLOCK>_<Traits>_run(
//       const Group_norm_nhwc_<PASS_NAME>_params &params, const dim3 &grid, cudaStream_t stream)
// THREADS_PER_BLOCK is accepted only so the parameter list matches GN_ONE_PASS_RUN_FUNCTION;
// it does not appear in the generated name or signature.
#define GN_ONE_PASS_RUN_FUNCTION_NAME(Traits, ACTS_PER_BLOCK, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME) \
  void group_norm_nhwc_##PASS_NAME##_one_pass_##CHANNELS_PER_GROUP##_##ACTS_PER_BLOCK##_##Traits##_run(         \
      const Group_norm_nhwc_##PASS_NAME##_params &params, const dim3 &grid, cudaStream_t stream)

// Defines the launcher whose signature is produced by GN_ONE_PASS_RUN_FUNCTION_NAME.
// The generated function launches
//   group_norm_nhwc_<PASS_NAME>_one_pass_kernel<Traits, ACTS_PER_BLOCK, CHANNELS_PER_GROUP, THREADS_PER_BLOCK>
// with THREADS_PER_BLOCK threads per block, no dynamic shared memory, on `stream`.
// A grid with a single block in x goes through cudaLaunchKernel; anything larger is launched
// with cudaLaunchCooperativeKernel. The only kernel argument is a pointer to `params`.
// Every runtime call, including the final cudaGetLastError(), is wrapped in CHECK_CUDA.
#define GN_ONE_PASS_RUN_FUNCTION(Traits, ACTS_PER_BLOCK, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME)            \
  GN_ONE_PASS_RUN_FUNCTION_NAME(Traits, ACTS_PER_BLOCK, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME) {           \
    auto kernel_fn =                                                                                                  \
        group_norm_nhwc_##PASS_NAME##_one_pass_kernel<Traits, ACTS_PER_BLOCK, CHANNELS_PER_GROUP, THREADS_PER_BLOCK>; \
                                                                                                                      \
    /* The launch APIs take an array of pointers to the kernel arguments; there is exactly one. */                    \
    const Group_norm_nhwc_##PASS_NAME##_params *params_ptr = &params;                                                 \
    void **kernel_args = (void **)&params_ptr;                                                                        \
    const dim3 block_dim(THREADS_PER_BLOCK);                                                                          \
                                                                                                                      \
    if (grid.x <= 1) {                                                                                                \
      CHECK_CUDA(cudaLaunchKernel((const void *)kernel_fn, grid, block_dim, kernel_args, 0, stream));                 \
    } else {                                                                                                          \
      CHECK_CUDA(cudaLaunchCooperativeKernel((const void *)kernel_fn, grid, block_dim, kernel_args, 0, stream));      \
    }                                                                                                                 \
                                                                                                                      \
    CHECK_CUDA(cudaGetLastError());                                                                                   \
  }

//////////////////////////////////////////////////////////////////////////////////////////////////

// Expands to the signature (no body, no trailing `;`) of the occupancy query function:
//   int group_norm_nhwc_<PASS_NAME>_one_pass_<CHANNELS_PER_GROUP>_<ACTS_PER_BLOCK>_<Traits>_blocks_per_sm()
// THREADS_PER_BLOCK is accepted only so the parameter list matches
// GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION; it is not part of the generated name.
#define GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION_NAME(Traits, ACTS_PER_BLOCK, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, \
                                                PASS_NAME)                                                     \
  int group_norm_nhwc_##PASS_NAME##_one_pass_##CHANNELS_PER_GROUP##_##ACTS_PER_BLOCK##_##Traits##_blocks_per_sm()

// Defines the function whose signature is produced by GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION_NAME.
// The generated function asks the CUDA runtime (cudaOccupancyMaxActiveBlocksPerMultiprocessor)
// how many blocks of
//   group_norm_nhwc_<PASS_NAME>_one_pass_kernel<Traits, ACTS_PER_BLOCK, CHANNELS_PER_GROUP, THREADS_PER_BLOCK>
// can be resident per multiprocessor for a block size of THREADS_PER_BLOCK and zero bytes of
// dynamic shared memory, and returns that count. Runtime calls are wrapped in CHECK_CUDA.
#define GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION(Traits, ACTS_PER_BLOCK, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME)  \
  GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION_NAME(Traits, ACTS_PER_BLOCK, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME) { \
    auto kernel_fn =                                                                                                  \
        group_norm_nhwc_##PASS_NAME##_one_pass_kernel<Traits, ACTS_PER_BLOCK, CHANNELS_PER_GROUP, THREADS_PER_BLOCK>; \
                                                                                                                      \
    const int block_size = THREADS_PER_BLOCK;                                                                         \
    int max_active_blocks = 0;                                                                                        \
    CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, kernel_fn, block_size, 0));          \
    CHECK_CUDA(cudaGetLastError());                                                                                   \
                                                                                                                      \
    return max_active_blocks;                                                                                         \
  }

//////////////////////////////////////////////////////////////////////////////////////////////////

// Applies the macro FUNCTION once for each ACTS_PER_BLOCK value in {512, 256, 128, 64},
// forwarding Traits, CHANNELS_PER_GROUP, THREADS_PER_BLOCK and PASS_NAME unchanged.
// Each expansion is followed by `;`, so FUNCTION may expand either to a declaration
// (signature only) or to a full definition (the `;` is then an empty declaration).
#define GN_ONE_PASS_(FUNCTION, Traits, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME) \
  FUNCTION(Traits, 512, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME);               \
  FUNCTION(Traits, 256, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME);               \
  FUNCTION(Traits, 128, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME);               \
  FUNCTION(Traits, 64, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME);

// Defines the `_run` launcher functions (GN_ONE_PASS_RUN_FUNCTION) for every Traits name in the
// list below, and for every ACTS_PER_BLOCK value that GN_ONE_PASS_ enumerates.
// The Traits list must stay in sync with the lists in the sibling *_DECLARATION / *_DEFINITION
// macros and in the selection-dispatch macros, since they all paste the same symbol names.
#define GN_ONE_PASS_RUN_DEFINITION(CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME)                     \
  GN_ONE_PASS_(GN_ONE_PASS_RUN_FUNCTION, Fp32IOFp16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \
  GN_ONE_PASS_(GN_ONE_PASS_RUN_FUNCTION, Fp32IOBf16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \
  GN_ONE_PASS_(GN_ONE_PASS_RUN_FUNCTION, Fp32IOFp32W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \
  GN_ONE_PASS_(GN_ONE_PASS_RUN_FUNCTION, Fp16IOFp16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \
  GN_ONE_PASS_(GN_ONE_PASS_RUN_FUNCTION, Fp16IOBf16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \
  GN_ONE_PASS_(GN_ONE_PASS_RUN_FUNCTION, Fp16IOFp32W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \
  GN_ONE_PASS_(GN_ONE_PASS_RUN_FUNCTION, Bf16IOFp16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \
  GN_ONE_PASS_(GN_ONE_PASS_RUN_FUNCTION, Bf16IOBf16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \
  GN_ONE_PASS_(GN_ONE_PASS_RUN_FUNCTION, Bf16IOFp32W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME);

// Declares (signature + `;`, no body) the `_run` launcher functions for every Traits name in the
// list below and every ACTS_PER_BLOCK value that GN_ONE_PASS_ enumerates.
// THREADS_PER_BLOCK is forwarded but GN_ONE_PASS_RUN_FUNCTION_NAME does not use it, so any
// placeholder value works here.
#define GN_ONE_PASS_RUN_DECLARATION(CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME)                         \
  GN_ONE_PASS_(GN_ONE_PASS_RUN_FUNCTION_NAME, Fp32IOFp16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \
  GN_ONE_PASS_(GN_ONE_PASS_RUN_FUNCTION_NAME, Fp32IOBf16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \
  GN_ONE_PASS_(GN_ONE_PASS_RUN_FUNCTION_NAME, Fp32IOFp32W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \
  GN_ONE_PASS_(GN_ONE_PASS_RUN_FUNCTION_NAME, Fp16IOFp16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \
  GN_ONE_PASS_(GN_ONE_PASS_RUN_FUNCTION_NAME, Fp16IOBf16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \
  GN_ONE_PASS_(GN_ONE_PASS_RUN_FUNCTION_NAME, Fp16IOFp32W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \
  GN_ONE_PASS_(GN_ONE_PASS_RUN_FUNCTION_NAME, Bf16IOFp16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \
  GN_ONE_PASS_(GN_ONE_PASS_RUN_FUNCTION_NAME, Bf16IOBf16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \
  GN_ONE_PASS_(GN_ONE_PASS_RUN_FUNCTION_NAME, Bf16IOFp32W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME);

// Defines the `_blocks_per_sm` occupancy query functions (GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION)
// for every Traits name in the list below and every ACTS_PER_BLOCK value that GN_ONE_PASS_
// enumerates.
#define GN_ONE_PASS_BLOCKS_PER_SM_DEFINITION(CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME)                     \
  GN_ONE_PASS_(GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION, Fp32IOFp16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \
  GN_ONE_PASS_(GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION, Fp32IOBf16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \
  GN_ONE_PASS_(GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION, Fp32IOFp32W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \
  GN_ONE_PASS_(GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION, Fp16IOFp16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \
  GN_ONE_PASS_(GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION, Fp16IOBf16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \
  GN_ONE_PASS_(GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION, Fp16IOFp32W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \
  GN_ONE_PASS_(GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION, Bf16IOFp16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \
  GN_ONE_PASS_(GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION, Bf16IOBf16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME); \
  GN_ONE_PASS_(GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION, Bf16IOFp32W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME);

// Declares (signature + `;`, no body) the `_blocks_per_sm` occupancy query functions for every
// Traits name in the list below and every ACTS_PER_BLOCK value that GN_ONE_PASS_ enumerates.
// THREADS_PER_BLOCK is forwarded but GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION_NAME does not use it,
// so any placeholder value works here.
#define GN_ONE_PASS_BLOCKS_PER_SM_DECLARATION(CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME)             \
  GN_ONE_PASS_(GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION_NAME, Fp32IOFp16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, \
               PASS_NAME);                                                                                  \
  GN_ONE_PASS_(GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION_NAME, Fp32IOBf16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, \
               PASS_NAME);                                                                                  \
  GN_ONE_PASS_(GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION_NAME, Fp32IOFp32W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, \
               PASS_NAME);                                                                                  \
  GN_ONE_PASS_(GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION_NAME, Fp16IOFp16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, \
               PASS_NAME);                                                                                  \
  GN_ONE_PASS_(GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION_NAME, Fp16IOBf16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, \
               PASS_NAME);                                                                                  \
  GN_ONE_PASS_(GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION_NAME, Fp16IOFp32W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, \
               PASS_NAME);                                                                                  \
  GN_ONE_PASS_(GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION_NAME, Bf16IOFp16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, \
               PASS_NAME);                                                                                  \
  GN_ONE_PASS_(GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION_NAME, Bf16IOBf16W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, \
               PASS_NAME);                                                                                  \
  GN_ONE_PASS_(GN_ONE_PASS_BLOCKS_PER_SM_FUNCTION_NAME, Bf16IOFp32W, CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME);

// Emits both families of definitions for one (CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME)
// combination: the `_run` launchers followed by the `_blocks_per_sm` occupancy queries.
#define GN_ONE_PASS_DEFINITION(CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME) \
  GN_ONE_PASS_RUN_DEFINITION(CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME)   \
  GN_ONE_PASS_BLOCKS_PER_SM_DEFINITION(CHANNELS_PER_GROUP, THREADS_PER_BLOCK, PASS_NAME)

// GN_ONE_PASS_DEFINITION with PASS_NAME fixed to `fwd`.
#define GN_FWD_ONE_PASS_DEFINITION(CHANNELS_PER_GROUP, THREADS_PER_BLOCK) \
  GN_ONE_PASS_DEFINITION(CHANNELS_PER_GROUP, THREADS_PER_BLOCK, fwd)

// GN_ONE_PASS_DEFINITION with PASS_NAME fixed to `bwd`.
#define GN_BWD_ONE_PASS_DEFINITION(CHANNELS_PER_GROUP, THREADS_PER_BLOCK) \
  GN_ONE_PASS_DEFINITION(CHANNELS_PER_GROUP, THREADS_PER_BLOCK, bwd)

// Emits the `fwd` definitions followed by the `bwd` definitions for the same
// CHANNELS_PER_GROUP / THREADS_PER_BLOCK pair.
#define GN_FWD_BWD_ONE_PASS_DEFINITION(CHANNELS_PER_GROUP, THREADS_PER_BLOCK) \
  GN_FWD_ONE_PASS_DEFINITION(CHANNELS_PER_GROUP, THREADS_PER_BLOCK)           \
  GN_BWD_ONE_PASS_DEFINITION(CHANNELS_PER_GROUP, THREADS_PER_BLOCK)

////////////////////////////////////////////////////////////////////////////////////////////////////

// One link of an if/else-if dispatch chain. Expects a variable named `params` in scope at the
// expansion site. When params.hw >= HW_THRESHOLD, params.channels_per_group == CHANNELS_PER_GROUP
// and params.precision == PrecisionMode::PRECISION, assigns to `function` the symbol
//   group_norm_nhwc_<PASS_NAME>_one_pass_<CHANNELS_PER_GROUP>_<ACTS_PER_BLOCK>_<Traits><FUNC_POSTFIX>
// The expansion deliberately ends in a dangling `else`, so it must be followed by another
// selection statement or by a final fallback statement supplied by the caller.
#define GN_SELECTION_STATEMENT(function, Traits, PRECISION, FUNC_POSTFIX, HW_THRESHOLD, ACTS_PER_BLOCK,          \
                               CHANNELS_PER_GROUP, PASS_NAME)                                                    \
  if (params.hw >= HW_THRESHOLD && params.channels_per_group == CHANNELS_PER_GROUP &&                            \
      params.precision == PrecisionMode::PRECISION) {                                                            \
    function =                                                                                                   \
        group_norm_nhwc_##PASS_NAME##_one_pass_##CHANNELS_PER_GROUP##_##ACTS_PER_BLOCK##_##Traits##FUNC_POSTFIX; \
  } else

// Same as GN_SELECTION_STATEMENT, with one extra condition: the branch is only eligible when
// CHANNELS_PER_GROUP >= LIMIT_CPG. Both operands of that comparison are macro arguments (the
// callers in this file pass literals for LIMIT_CPG), so it is not a check on `params`.
// Like GN_SELECTION_STATEMENT, the expansion ends in a dangling `else`.
#define GN_SELECTION_STATEMENT_WITH_CPG_LIMIT(function, Traits, PRECISION, FUNC_POSTFIX, HW_THRESHOLD, ACTS_PER_BLOCK, \
                                              CHANNELS_PER_GROUP, PASS_NAME, LIMIT_CPG)                                \
  if (params.hw >= HW_THRESHOLD && params.channels_per_group == CHANNELS_PER_GROUP &&                                  \
      params.precision == PrecisionMode::PRECISION && CHANNELS_PER_GROUP >= LIMIT_CPG) {                               \
    function =                                                                                                         \
        group_norm_nhwc_##PASS_NAME##_one_pass_##CHANNELS_PER_GROUP##_##ACTS_PER_BLOCK##_##Traits##FUNC_POSTFIX;       \
  } else

// Builds the ACTS_PER_BLOCK selection chain for one Traits / PRECISION / CHANNELS_PER_GROUP
// combination. The links are tested in the order written and the first match wins:
//   1. hw >= 1024 and CHANNELS_PER_GROUP >= 80   -> ACTS_PER_BLOCK 128
//   2. hw >= 256  and CHANNELS_PER_GROUP >= 160  -> ACTS_PER_BLOCK 128
//   3. hw >= 512                                 -> ACTS_PER_BLOCK 512
//   4. hw >= 256                                 -> ACTS_PER_BLOCK 256
//   5. hw >= 128                                 -> ACTS_PER_BLOCK 128
//   6. hw >= 0                                   -> ACTS_PER_BLOCK 64
// Every link also requires the params.channels_per_group and params.precision match performed by
// the underlying selection macros. The chain ends in a dangling `else`; the caller must append
// the next chain or a final fallback statement.
#define GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Traits, PRECISION, CHANNELS_PER_GROUP,        \
                                                                    FUNC_POSTFIX, function, PASS_NAME)            \
  GN_SELECTION_STATEMENT_WITH_CPG_LIMIT(function, Traits, PRECISION, FUNC_POSTFIX, 1024, 128, CHANNELS_PER_GROUP, \
                                        PASS_NAME, 80)                                                            \
  GN_SELECTION_STATEMENT_WITH_CPG_LIMIT(function, Traits, PRECISION, FUNC_POSTFIX, 256, 128, CHANNELS_PER_GROUP,  \
                                        PASS_NAME, 160)                                                           \
  GN_SELECTION_STATEMENT(function, Traits, PRECISION, FUNC_POSTFIX, 512, 512, CHANNELS_PER_GROUP, PASS_NAME)      \
  GN_SELECTION_STATEMENT(function, Traits, PRECISION, FUNC_POSTFIX, 256, 256, CHANNELS_PER_GROUP, PASS_NAME)      \
  GN_SELECTION_STATEMENT(function, Traits, PRECISION, FUNC_POSTFIX, 128, 128, CHANNELS_PER_GROUP, PASS_NAME)      \
  GN_SELECTION_STATEMENT(function, Traits, PRECISION, FUNC_POSTFIX, 0, 64, CHANNELS_PER_GROUP, PASS_NAME)

// Forward-pass dispatch: concatenates the ACTS_PER_BLOCK selection chain for each
// (Traits, PrecisionMode enumerator) pair listed below, with PASS_NAME fixed to `fwd`.
// The first argument of each line is the Traits spelling pasted into the symbol name; the second
// is the PrecisionMode enumerator compared against params.precision.
// The combined chain still ends in a dangling `else`; the caller must supply the final statement.
#define GN_FWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(CHANNELS_PER_GROUP, FUNC_POSTFIX, function) \
  GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Fp32IOFp16W, FP32IOFP16W, CHANNELS_PER_GROUP,         \
                                                              FUNC_POSTFIX, function, fwd)                          \
  GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Fp32IOBf16W, FP32IOBF16W, CHANNELS_PER_GROUP,         \
                                                              FUNC_POSTFIX, function, fwd)                          \
  GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Fp32IOFp32W, FP32IOFP32W, CHANNELS_PER_GROUP,         \
                                                              FUNC_POSTFIX, function, fwd)                          \
  GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Fp16IOFp16W, FP16IOFP16W, CHANNELS_PER_GROUP,         \
                                                              FUNC_POSTFIX, function, fwd)                          \
  GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Fp16IOBf16W, FP16IOBF16W, CHANNELS_PER_GROUP,         \
                                                              FUNC_POSTFIX, function, fwd)                          \
  GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Fp16IOFp32W, FP16IOFP32W, CHANNELS_PER_GROUP,         \
                                                              FUNC_POSTFIX, function, fwd)                          \
  GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Bf16IOFp16W, BF16IOFP16W, CHANNELS_PER_GROUP,         \
                                                              FUNC_POSTFIX, function, fwd)                          \
  GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Bf16IOBf16W, BF16IOBF16W, CHANNELS_PER_GROUP,         \
                                                              FUNC_POSTFIX, function, fwd)                          \
  GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Bf16IOFp32W, BF16IOFP32W, CHANNELS_PER_GROUP,         \
                                                              FUNC_POSTFIX, function, fwd)

// Backward-pass dispatch: concatenates the ACTS_PER_BLOCK selection chain for each
// (Traits, PrecisionMode enumerator) pair listed below, with PASS_NAME fixed to `bwd`.
// The first argument of each line is the Traits spelling pasted into the symbol name; the second
// is the PrecisionMode enumerator compared against params.precision.
// The combined chain still ends in a dangling `else`; the caller must supply the final statement.
#define GN_BWD_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(CHANNELS_PER_GROUP, FUNC_POSTFIX, function) \
  GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Fp32IOFp16W, FP32IOFP16W, CHANNELS_PER_GROUP,         \
                                                              FUNC_POSTFIX, function, bwd)                          \
  GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Fp32IOBf16W, FP32IOBF16W, CHANNELS_PER_GROUP,         \
                                                              FUNC_POSTFIX, function, bwd)                          \
  GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Fp32IOFp32W, FP32IOFP32W, CHANNELS_PER_GROUP,         \
                                                              FUNC_POSTFIX, function, bwd)                          \
  GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Fp16IOFp16W, FP16IOFP16W, CHANNELS_PER_GROUP,         \
                                                              FUNC_POSTFIX, function, bwd)                          \
  GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Fp16IOBf16W, FP16IOBF16W, CHANNELS_PER_GROUP,         \
                                                              FUNC_POSTFIX, function, bwd)                          \
  GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Fp16IOFp32W, FP16IOFP32W, CHANNELS_PER_GROUP,         \
                                                              FUNC_POSTFIX, function, bwd)                          \
  GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Bf16IOFp16W, BF16IOFP16W, CHANNELS_PER_GROUP,         \
                                                              FUNC_POSTFIX, function, bwd)                          \
  GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Bf16IOBf16W, BF16IOBF16W, CHANNELS_PER_GROUP,         \
                                                              FUNC_POSTFIX, function, bwd)                          \
  GN_SELECTION_STATEMENT_HW_THRESHOLD_ACTS_PER_BLOCK_DISPATCH(Bf16IOFp32W, BF16IOFP32W, CHANNELS_PER_GROUP,         \
                                                              FUNC_POSTFIX, function, bwd)

////////////////////////////////////////////////////////////////////////////////////////////////////

// Declares both function families (`_run` and `_blocks_per_sm`) for one CHANNELS_PER_GROUP and
// PASS_NAME. THREADS_PER_BLOCK is passed as a dummy 0 because the *_FUNCTION_NAME macros that
// build the signatures never use it.
#define GN_ONE_PASS_DECLARATION(CHANNELS_PER_GROUP, PASS_NAME)                    \
  GN_ONE_PASS_RUN_DECLARATION(CHANNELS_PER_GROUP, /* dummy value */ 0, PASS_NAME) \
  GN_ONE_PASS_BLOCKS_PER_SM_DECLARATION(CHANNELS_PER_GROUP, /* dummy value */ 0, PASS_NAME)

// GN_ONE_PASS_DECLARATION with PASS_NAME fixed to `fwd`.
#define GN_FWD_ONE_PASS_DECLARATION(CHANNELS_PER_GROUP) GN_ONE_PASS_DECLARATION(CHANNELS_PER_GROUP, fwd)

// GN_ONE_PASS_DECLARATION with PASS_NAME fixed to `bwd`.
#define GN_BWD_ONE_PASS_DECLARATION(CHANNELS_PER_GROUP) GN_ONE_PASS_DECLARATION(CHANNELS_PER_GROUP, bwd)

////////////////////////////////////////////////////////////////////////////////////////////////////

// Launches the two-pass kernel `Kernel<Precision, THREADS>` with `<<<grid, THREADS, 0, stream>>>`,
// choosing THREADS from params.channels_per_block. Expects `params`, `grid` and `stream` to be in
// scope at the expansion site.
//
// Invariant: in every branch the launched block size equals the kernel's second template argument,
// and both are channels_per_block / 2. Presumably the kernel uses that template argument as its
// compile-time thread count, so a mismatch with the real block size would be incorrect - keep the
// two numbers identical when adding a branch.
//
// Unsupported channels_per_block values reach assert(false). With the standard assert this is
// compiled out when NDEBUG is defined, in which case no kernel is launched.
#define CALL_TWO_PASS_KERNEL(Kernel, Precision)               \
  if (params.channels_per_block == 320) {                     \
    Kernel<Precision, 160><<<grid, 160, 0, stream>>>(params); \
  } else if (params.channels_per_block == 280) {              \
    Kernel<Precision, 140><<<grid, 140, 0, stream>>>(params); \
  } else if (params.channels_per_block == 208) {              \
    Kernel<Precision, 104><<<grid, 104, 0, stream>>>(params); \
  } else if (params.channels_per_block == 240) {              \
    Kernel<Precision, 120><<<grid, 120, 0, stream>>>(params); \
  } else if (params.channels_per_block == 512) {              \
    Kernel<Precision, 256><<<grid, 256, 0, stream>>>(params); \
  } else if (params.channels_per_block == 448) {              \
    Kernel<Precision, 224><<<grid, 224, 0, stream>>>(params); \
  } else if (params.channels_per_block == 384) {              \
    Kernel<Precision, 192><<<grid, 192, 0, stream>>>(params); \
  } else if (params.channels_per_block == 256) {              \
    Kernel<Precision, 128><<<grid, 128, 0, stream>>>(params); \
  } else if (params.channels_per_block == 128) {              \
    Kernel<Precision, 64><<<grid, 64, 0, stream>>>(params);   \
  } else if (params.channels_per_block == 336) {              \
    Kernel<Precision, 168><<<grid, 168, 0, stream>>>(params); \
  } else if (params.channels_per_block == 392) {              \
    Kernel<Precision, 196><<<grid, 196, 0, stream>>>(params); \
  } else {                                                    \
    assert(false);                                            \
  }

////////////////////////////////////////////////////////////////////////////////////////////////////
